{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# standard imports\n",
    "import torch\n",
    "from amdshark.iree_utils import get_iree_compiled_module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "# torch dynamo related imports\n",
    "try:\n",
    "    import torchdynamo\n",
    "    from torchdynamo.optimizations.backends import create_backend\n",
    "    from torchdynamo.optimizations.subgraph import SubGraph\n",
    "except ModuleNotFoundError:\n",
    "    print(\n",
    "        \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
    "    )\n",
    "    exit()\n",
    "\n",
    "# torch-mlir imports for compiling\n",
    "from torch_mlir import compile, OutputType"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
     "[TorchDynamo](https://github.com/pytorch/torchdynamo) is a compiler for PyTorch programs that uses the [frame evaluation API](https://www.python.org/dev/peps/pep-0523/) in CPython to dynamically modify Python bytecode right before it is executed. It creates an FX graph through bytecode analysis and is designed to mix Python execution with compiled backends."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "def toy_example(*args):\n",
    "    a, b = args\n",
    "\n",
    "    x = a / (torch.abs(a) + 1)\n",
    "    if b.sum() < 0:\n",
    "        b = b * -1\n",
    "    return x * b"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
     "# compiler that lowers an fx_graph to linalg through torch-mlir\n",
    "def __torch_mlir(fx_graph, *args, **kwargs):\n",
    "    assert isinstance(\n",
    "        fx_graph, torch.fx.GraphModule\n",
    "    ), \"Model must be an FX GraphModule.\"\n",
    "\n",
    "    def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule):\n",
    "        \"\"\"Replace tuple with tuple element in functions that return one-element tuples.\"\"\"\n",
    "\n",
    "        for node in fx_g.graph.nodes:\n",
    "            if node.op == \"output\":\n",
    "                assert (\n",
    "                    len(node.args) == 1\n",
    "                ), \"Output node must have a single argument\"\n",
    "                node_arg = node.args[0]\n",
    "                if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
    "                    node.args = (node_arg[0],)\n",
    "        fx_g.graph.lint()\n",
    "        fx_g.recompile()\n",
    "        return fx_g\n",
    "\n",
    "    fx_graph = _unwrap_single_tuple_return(fx_graph)\n",
    "    ts_graph = torch.jit.script(fx_graph)\n",
    "\n",
     "    # torchdynamo munges the args differently depending on whether you use\n",
    "    # the @torchdynamo.optimize decorator or the context manager\n",
    "    if isinstance(args, tuple):\n",
    "        args = list(args)\n",
    "    assert isinstance(args, list)\n",
    "    if len(args) == 1 and isinstance(args[0], list):\n",
    "        args = args[0]\n",
    "\n",
    "    linalg_module = compile(\n",
    "        ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
    "    )\n",
    "    callable, _ = get_iree_compiled_module(\n",
    "        linalg_module, \"cuda\", func_name=\"forward\"\n",
    "    )\n",
    "\n",
    "    def forward(*inputs):\n",
    "        return callable(*inputs)\n",
    "\n",
    "    return forward"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
     "The simplest way to use TorchDynamo is with the `torchdynamo.optimize` context manager:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 1 device(s).\n",
      "Device: 0\n",
      "  Name: NVIDIA GeForce RTX 3080\n",
      "  Compute Capability: 8.6\n",
      "[-0.40066046 -0.4210303   0.03225489 -0.44849953  0.10370405 -0.04422468\n",
      "  0.33262825 -0.20109026  0.02102537 -0.24882983]\n",
      "[-0.07824923 -0.17004533  0.06439921 -0.06163602  0.26633525 -1.1560082\n",
      " -0.06660341  0.24227881  0.1462235  -0.32055548]\n",
      "[-0.01464001  0.442209   -0.0607936  -0.5477967  -0.25226554 -0.08588809\n",
      " -0.30497575  0.00061084 -0.50069696  0.2317973 ]\n",
      "[ 0.25726247  0.39388427 -0.24093066  0.12316308 -0.01981307  0.5661146\n",
      "  0.26199922  0.8123446  -0.01576749  0.30846444]\n",
      "[ 0.7878203  -0.45975062 -0.29956317 -0.07032048 -0.55817443 -0.62506855\n",
      " -1.6837492  -0.38442805  0.28220773 -1.5325156 ]\n",
      "[ 0.07975311  0.67754704 -0.30927914  0.00347631 -0.07326564  0.01893554\n",
      " -0.7518105  -0.03078967 -0.07623022  0.38865626]\n",
      "[-0.7751679  -0.5841397  -0.6622711   0.18574935 -0.6049372   0.02844244\n",
      " -0.20471913  0.3337415  -0.3619432  -0.35087156]\n",
      "[-0.08569919 -0.10775139 -0.02338934  0.21933547 -0.46712473  0.00062137\n",
      " -0.58207744  0.06457533  0.18276742  0.03866556]\n",
      "[-0.2311981  -0.43036282  0.20561649 -0.10363232 -0.13248594  0.02885137\n",
      " -0.31241602 -0.36907142  0.08861586  0.2331427 ]\n",
      "[-0.07273526 -0.31246194 -0.24218291 -0.24145737  0.0364486   0.14382267\n",
      " -0.00531162  0.15447603 -0.5220248  -0.09016377]\n"
     ]
    }
   ],
   "source": [
    "with torchdynamo.optimize(__torch_mlir):\n",
    "    for _ in range(10):\n",
    "        print(toy_example(torch.randn(10), torch.randn(10)))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "It can also be used through a decorator:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "@create_backend\n",
    "def torch_mlir(subgraph, *args, **kwargs):\n",
    "    assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
    "    return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
    "\n",
    "\n",
    "@torchdynamo.optimize(\"torch_mlir\")\n",
    "def toy_example2(*args):\n",
    "    a, b = args\n",
    "\n",
    "    x = a / (torch.abs(a) + 1)\n",
    "    if b.sum() < 0:\n",
    "        b = b * -1\n",
    "    return x * b"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 1 device(s).\n",
      "Device: 0\n",
      "  Name: NVIDIA GeForce RTX 3080\n",
      "  Compute Capability: 8.6\n",
      "[-0.35494277  0.03409214 -0.02271946  0.7335942   0.03122527 -0.41881397\n",
      " -0.6609761  -0.6418614   0.29336175 -0.01973678]\n",
      "[-2.7246824e-01 -3.5543957e-01  6.0087401e-01 -7.4570496e-03\n",
      " -4.2481605e-02 -5.0296803e-04  7.2928613e-01 -1.4673788e-03\n",
      " -2.7621329e-01 -6.0995776e-02]\n",
      "[-0.03165906  0.3889693   0.24052973  0.27279532 -0.02773128 -0.12602475\n",
      " -1.0124422   0.5720256  -0.35437614 -0.20992722]\n",
      "[-0.41831446  0.5525326  -0.29749998 -0.17044766  0.11804754 -0.05210691\n",
      " -0.46145165 -0.8776549   0.10090438  0.17463352]\n",
      "[ 0.02194221  0.20959911  0.26973712  0.12551276 -0.0020404   0.1490246\n",
      " -0.04456685  1.1100804   0.8105744   0.6676846 ]\n",
      "[ 0.06528181 -0.13591261  0.5370964  -0.4398162  -0.03372452  0.9691372\n",
      " -0.01120087  0.2947028   0.4804801  -0.3324341 ]\n",
      "[ 0.33549032 -0.23001772 -0.08681437  0.16490957 -0.11223086  0.09168988\n",
      "  0.02403045  0.17344482  0.46406478 -0.00129451]\n",
      "[-0.27475086  0.42384806  1.9090122  -0.41147137 -0.6888369   0.08435658\n",
      " -0.26628923 -0.17436793 -0.8058869  -0.02582378]\n",
      "[-0.10109414  0.08681287 -0.10055986  0.6858881   0.29267687 -0.02797117\n",
      " -0.01425194  0.4882803   0.3551982  -0.858935  ]\n",
      "[-0.22086617  0.524994    0.17721705 -0.03813264 -0.54570735 -0.4421502\n",
      "  0.11938014 -0.01122053  0.39294165 -0.61770755]\n"
     ]
    }
   ],
   "source": [
    "for _ in range(10):\n",
    "    print(toy_example2(torch.randn(10), torch.randn(10)))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}